/*
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package main

import (
	"bytes"
	"context"
	"flag"
	"fmt"
	"go/format"
	"log"
	"os"
	"sort"
	"strings"

	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/ec2"
	ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
	"github.com/samber/lo"

	sdk "github.com/aws/karpenter-provider-aws/pkg/aws"
)

// packageHeader is written to the generated file after the build tag and the
// license boilerplate. It declares the output package and imports every
// package referenced by the emitted literals: aws.* pointer helpers, the
// ec2/ec2types output types, and lo.ToPtr used for offering locations.
const packageHeader = `
package fake

import (
	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/ec2"
	ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
	"github.com/samber/lo"
)

// GENERATED FILE. DO NOT EDIT DIRECTLY.
// Update hack/code/instancetype_testdata_gen.go and re-generate to edit
// You can add instance types by adding to the --instance-types CLI flag

`

// Command-line configuration for the generator.
var (
	// instanceTypesStr is the raw, comma-separated --instance-types value.
	instanceTypesStr string
	// outFile is the destination path for the generated Go source (--out-file).
	outFile string
)

// init registers the generator's flags and parses os.Args immediately, so the
// flag variables are already populated when main starts.
// NOTE(review): parsing in init means this package cannot host Go tests that
// receive -test.* flags; consider moving flag.Parse into main together with a
// matching change there.
func init() {
	flag.StringVar(&instanceTypesStr, "instance-types", "", "comma-separated list of instance types to auto-generate static test data from")
	flag.StringVar(&outFile, "out-file", "zz_generated.describe_instance_types.go", "file to output the generated data")
	flag.Parse()
}

func main() {
	if err := os.Setenv("AWS_SDK_LOAD_CONFIG", "true"); err != nil {
		log.Fatalf("setting AWS_SDK_LOAD_CONFIG, %s", err)
	}
	if err := os.Setenv("AWS_REGION", "us-east-1"); err != nil {
		log.Fatalf("setting AWS_REGION, %s", err)
	}
	ctx := context.Background()
	cfg := lo.Must(config.LoadDefaultConfig(ctx))
	ec2api := ec2.NewFromConfig(cfg)
	instanceTypes := strings.Split(instanceTypesStr, ",")

	src := &bytes.Buffer{}
	fmt.Fprintln(src, "//go:build !ignore_autogenerated")
	license := lo.Must(os.ReadFile("hack/boilerplate.go.txt"))
	fmt.Fprintln(src, string(license))
	fmt.Fprint(src, packageHeader)
	fmt.Fprintln(src, getDescribeInstanceTypesOutput(ctx, ec2api, instanceTypes))
	fmt.Fprintln(src, getDescribeInstanceTypeOfferingsOutput())

	// Format and print to the file
	formatted, err := format.Source(src.Bytes())
	if err != nil {
		log.Fatalf("formatting generated source, %s", err)
	}
	if err := os.WriteFile(outFile, formatted, 0644); err != nil {
		log.Fatalf("writing output, %s", err)
	}
}

// getDescribeInstanceTypesOutput calls the EC2 DescribeInstanceTypes API for
// the given instance type names. It renders the results as Go source declaring
// defaultDescribeInstanceTypesOutput. Results are sorted by instance type name
// so that regenerating produces a stable diff. Any API error is fatal.
func getDescribeInstanceTypesOutput(ctx context.Context, ec2api sdk.EC2API, instanceTypes []string) string {
	input := &ec2.DescribeInstanceTypesInput{
		InstanceTypes: lo.Map(instanceTypes, func(it string, _ int) ec2types.InstanceType {
			return ec2types.InstanceType(it)
		}),
	}

	// DescribeInstanceTypes is paginated; follow NextToken so that no requested
	// instance type is silently missing from the generated fixtures.
	var infos []ec2types.InstanceTypeInfo
	for {
		out, err := ec2api.DescribeInstanceTypes(ctx, input)
		if err != nil {
			log.Fatalf("describing instance types, %s", err)
		}
		infos = append(infos, out.InstanceTypes...)
		if lo.FromPtr(out.NextToken) == "" {
			break
		}
		input.NextToken = out.NextToken
	}

	// Sort them by name so that we get a consistent ordering
	sort.SliceStable(infos, func(i, j int) bool {
		return infos[i].InstanceType < infos[j].InstanceType
	})

	src := &bytes.Buffer{}
	fmt.Fprintln(src, "var defaultDescribeInstanceTypesOutput = &ec2.DescribeInstanceTypesOutput{")
	fmt.Fprintln(src, "InstanceTypes: []ec2types.InstanceTypeInfo{")
	for _, info := range infos {
		fmt.Fprintln(src, "{")
		fmt.Fprintln(src, getInstanceTypeInfo(info))
		fmt.Fprintln(src, "},")
	}
	fmt.Fprintln(src, "},")
	fmt.Fprintln(src, "}")
	return src.String()
}

// getDescribeInstanceTypeOfferingsOutput renders Go source declaring
// defaultDescribeInstanceTypeOfferingsOutput. The output has one offering per
// (instance type, zone) pair from the fixed table below. Instance types are
// emitted in sorted order, and each type's zones keep their declared order.
func getDescribeInstanceTypeOfferingsOutput() string {
	// zonesByType lists the fake availability zones each instance type is offered in.
	zonesByType := map[string][]string{
		"m5.large":       {"test-zone-1a", "test-zone-1b", "test-zone-1c", "test-zone-1a-local"},
		"m5.xlarge":      {"test-zone-1a", "test-zone-1b"},
		"m5.2xlarge":     {"test-zone-1a"},
		"m5.4xlarge":     {"test-zone-1a"},
		"m5.8xlarge":     {"test-zone-1a"},
		"p3.8xlarge":     {"test-zone-1a", "test-zone-1b"},
		"dl1.24xlarge":   {"test-zone-1a", "test-zone-1b"},
		"g4dn.8xlarge":   {"test-zone-1a", "test-zone-1b"},
		"g4ad.16xlarge":  {"test-zone-1a", "test-zone-1b"},
		"t3.large":       {"test-zone-1a", "test-zone-1b"},
		"inf2.xlarge":    {"test-zone-1a"},
		"inf2.24xlarge":  {"test-zone-1a"},
		"trn1.2xlarge":   {"test-zone-1a"},
		"c6g.large":      {"test-zone-1a"},
		"m5.metal":       {"test-zone-1a", "test-zone-1b"},
		"m6idn.32xlarge": {"test-zone-1a", "test-zone-1b", "test-zone-1c"},
		"m7i-flex.large": {"test-zone-1a"},
	}

	// Map iteration order is random; sort the names for deterministic output.
	names := make([]string, 0, len(zonesByType))
	for name := range zonesByType {
		names = append(names, name)
	}
	sort.Strings(names)

	var buf bytes.Buffer
	buf.WriteString("var defaultDescribeInstanceTypeOfferingsOutput = &ec2.DescribeInstanceTypeOfferingsOutput{\n")
	buf.WriteString("InstanceTypeOfferings: []ec2types.InstanceTypeOffering{\n")
	for _, name := range names {
		for _, zone := range zonesByType[name] {
			fmt.Fprintf(&buf, "{\nInstanceType: \"%s\",\nLocation: lo.ToPtr(\"%s\"),\n},\n", name, zone)
		}
	}
	buf.WriteString("},\n}\n")
	return buf.String()
}

// getInstanceTypeInfo renders the subset of an ec2types.InstanceTypeInfo used
// by the tests as the fields of a Go composite literal. The caller adds the
// surrounding braces. String values are emitted with %q so they are always
// valid, escaped Go literals. Optional sub-structs that are nil in the API
// response are omitted rather than dereferenced.
func getInstanceTypeInfo(info ec2types.InstanceTypeInfo) string {
	src := &bytes.Buffer{}

	fmt.Fprintf(src, "InstanceType: %q,\n", info.InstanceType)
	fmt.Fprintf(src, "SupportedUsageClasses: []ec2types.UsageClassType{%s},\n", getStringSliceData(info.SupportedUsageClasses))
	fmt.Fprintf(src, "SupportedVirtualizationTypes: []ec2types.VirtualizationType{%s},\n", getStringSliceData(info.SupportedVirtualizationTypes))
	fmt.Fprintf(src, "BurstablePerformanceSupported: aws.Bool(%t),\n", lo.FromPtr(info.BurstablePerformanceSupported))
	fmt.Fprintf(src, "BareMetal: aws.Bool(%t),\n", lo.FromPtr(info.BareMetal))
	fmt.Fprintf(src, "Hypervisor: %q,\n", info.Hypervisor)

	if proc := info.ProcessorInfo; proc != nil {
		fmt.Fprint(src, "ProcessorInfo: &ec2types.ProcessorInfo{\n")
		fmt.Fprintf(src, "Manufacturer: aws.String(%q),\n", lo.FromPtr(proc.Manufacturer))
		fmt.Fprintf(src, "SupportedArchitectures: []ec2types.ArchitectureType{%s},\n", getStringSliceData(proc.SupportedArchitectures))
		fmt.Fprintf(src, "SustainedClockSpeedInGhz: aws.Float64(%f),\n", lo.FromPtr(proc.SustainedClockSpeedInGhz))
		fmt.Fprint(src, "},\n")
	}
	if vcpu := info.VCpuInfo; vcpu != nil {
		fmt.Fprint(src, "VCpuInfo: &ec2types.VCpuInfo{\n")
		fmt.Fprintf(src, "DefaultCores: aws.Int32(%d),\n", lo.FromPtr(vcpu.DefaultCores))
		fmt.Fprintf(src, "DefaultVCpus: aws.Int32(%d),\n", lo.FromPtr(vcpu.DefaultVCpus))
		fmt.Fprint(src, "},\n")
	}
	if mem := info.MemoryInfo; mem != nil {
		fmt.Fprint(src, "MemoryInfo: &ec2types.MemoryInfo{\n")
		fmt.Fprintf(src, "SizeInMiB: aws.Int64(%d),\n", lo.FromPtr(mem.SizeInMiB))
		fmt.Fprint(src, "},\n")
	}
	if ebs := info.EbsInfo; ebs != nil {
		fmt.Fprint(src, "EbsInfo: &ec2types.EbsInfo{\n")
		if opt := ebs.EbsOptimizedInfo; opt != nil {
			fmt.Fprint(src, "EbsOptimizedInfo: &ec2types.EbsOptimizedInfo{\n")
			fmt.Fprintf(src, "BaselineBandwidthInMbps: aws.Int32(%d),\n", lo.FromPtr(opt.BaselineBandwidthInMbps))
			fmt.Fprintf(src, "BaselineIops: aws.Int32(%d),\n", lo.FromPtr(opt.BaselineIops))
			fmt.Fprintf(src, "BaselineThroughputInMBps: aws.Float64(%.2f),\n", lo.FromPtr(opt.BaselineThroughputInMBps))
			fmt.Fprintf(src, "MaximumBandwidthInMbps: aws.Int32(%d),\n", lo.FromPtr(opt.MaximumBandwidthInMbps))
			fmt.Fprintf(src, "MaximumIops: aws.Int32(%d),\n", lo.FromPtr(opt.MaximumIops))
			fmt.Fprintf(src, "MaximumThroughputInMBps: aws.Float64(%.2f),\n", lo.FromPtr(opt.MaximumThroughputInMBps))
			fmt.Fprint(src, "},\n")
		}
		fmt.Fprintf(src, "EbsOptimizedSupport: %q,\n", ebs.EbsOptimizedSupport)
		fmt.Fprintf(src, "EncryptionSupport: %q,\n", ebs.EncryptionSupport)
		fmt.Fprintf(src, "NvmeSupport: %q,\n", ebs.NvmeSupport)
		fmt.Fprint(src, "},\n")
	}
	if neuron := info.NeuronInfo; neuron != nil {
		fmt.Fprint(src, "NeuronInfo: &ec2types.NeuronInfo{\n")
		fmt.Fprint(src, "NeuronDevices: []ec2types.NeuronDeviceInfo{\n")
		for _, device := range neuron.NeuronDevices {
			// Fprint, not Fprintf: the rendered text is data, not a format
			// string, and any '%' in it must be written verbatim.
			fmt.Fprint(src, getNeuronDeviceInfo(device))
		}
		fmt.Fprint(src, "},\n")
		fmt.Fprint(src, "},\n")
	}
	if gpu := info.GpuInfo; gpu != nil {
		fmt.Fprint(src, "GpuInfo: &ec2types.GpuInfo{\n")
		fmt.Fprint(src, "Gpus: []ec2types.GpuDeviceInfo{\n")
		for _, device := range gpu.Gpus {
			fmt.Fprint(src, getGPUDeviceInfo(device))
		}
		fmt.Fprint(src, "},\n")
		fmt.Fprint(src, "},\n")
	}
	if storage := info.InstanceStorageInfo; storage != nil {
		fmt.Fprint(src, "InstanceStorageInfo: &ec2types.InstanceStorageInfo{\n")
		fmt.Fprintf(src, "NvmeSupport: %q,\n", storage.NvmeSupport)
		fmt.Fprintf(src, "TotalSizeInGB: aws.Int64(%d),\n", lo.FromPtr(storage.TotalSizeInGB))
		fmt.Fprint(src, "},\n")
	}
	if network := info.NetworkInfo; network != nil {
		fmt.Fprint(src, "NetworkInfo: &ec2types.NetworkInfo{\n")
		if efa := network.EfaInfo; efa != nil {
			fmt.Fprint(src, "EfaInfo: &ec2types.EfaInfo{\n")
			fmt.Fprintf(src, "MaximumEfaInterfaces: aws.Int32(%d),\n", lo.FromPtr(efa.MaximumEfaInterfaces))
			fmt.Fprint(src, "},\n")
		}
		fmt.Fprintf(src, "MaximumNetworkInterfaces: aws.Int32(%d),\n", lo.FromPtr(network.MaximumNetworkInterfaces))
		fmt.Fprintf(src, "Ipv4AddressesPerInterface: aws.Int32(%d),\n", lo.FromPtr(network.Ipv4AddressesPerInterface))
		fmt.Fprintf(src, "EncryptionInTransitSupported: aws.Bool(%t),\n", lo.FromPtr(network.EncryptionInTransitSupported))
		fmt.Fprintf(src, "DefaultNetworkCardIndex: aws.Int32(%d),\n", lo.FromPtr(network.DefaultNetworkCardIndex))
		fmt.Fprint(src, "NetworkCards: []ec2types.NetworkCardInfo{\n")
		for _, card := range network.NetworkCards {
			fmt.Fprint(src, getNetworkCardInfo(card))
		}
		fmt.Fprint(src, "},\n")
		fmt.Fprint(src, "},\n")
	}
	return src.String()
}

// getNetworkCardInfo renders one ec2types.NetworkCardInfo as a braced,
// comma-terminated element of a Go slice literal. Only the card index and its
// maximum network interface count are emitted.
func getNetworkCardInfo(info ec2types.NetworkCardInfo) string {
	index := lo.FromPtr(info.NetworkCardIndex)
	maxENIs := lo.FromPtr(info.MaximumNetworkInterfaces)
	return fmt.Sprintf("{\nNetworkCardIndex: aws.Int32(%d),\nMaximumNetworkInterfaces: aws.Int32(%d),\n},\n", index, maxENIs)
}

// getNeuronDeviceInfo renders one ec2types.NeuronDeviceInfo as a braced,
// comma-terminated element of a Go slice literal. The device name is emitted
// with %q so it is always a valid, escaped Go string. Nested CoreInfo and
// MemoryInfo structs are omitted when nil instead of being dereferenced.
func getNeuronDeviceInfo(info ec2types.NeuronDeviceInfo) string {
	src := &bytes.Buffer{}
	fmt.Fprint(src, "{\n")
	fmt.Fprintf(src, "Count: aws.Int32(%d),\n", lo.FromPtr(info.Count))
	fmt.Fprintf(src, "Name: aws.String(%q),\n", lo.FromPtr(info.Name))
	if core := info.CoreInfo; core != nil {
		fmt.Fprint(src, "CoreInfo: &ec2types.NeuronDeviceCoreInfo{\n")
		fmt.Fprintf(src, "Count: aws.Int32(%d),\n", lo.FromPtr(core.Count))
		fmt.Fprintf(src, "Version: aws.Int32(%d),\n", lo.FromPtr(core.Version))
		fmt.Fprint(src, "},\n")
	}
	if mem := info.MemoryInfo; mem != nil {
		fmt.Fprint(src, "MemoryInfo: &ec2types.NeuronDeviceMemoryInfo{\n")
		fmt.Fprintf(src, "SizeInMiB: aws.Int32(%d),\n", lo.FromPtr(mem.SizeInMiB))
		fmt.Fprint(src, "},\n")
	}
	fmt.Fprint(src, "},\n")
	return src.String()
}

// getGPUDeviceInfo renders one ec2types.GpuDeviceInfo as a braced,
// comma-terminated element of a Go slice literal. Name and manufacturer are
// emitted with %q so they are always valid, escaped Go strings. A nil
// MemoryInfo is omitted instead of being dereferenced.
func getGPUDeviceInfo(info ec2types.GpuDeviceInfo) string {
	src := &bytes.Buffer{}
	fmt.Fprint(src, "{\n")
	fmt.Fprintf(src, "Name: aws.String(%q),\n", lo.FromPtr(info.Name))
	fmt.Fprintf(src, "Manufacturer: aws.String(%q),\n", lo.FromPtr(info.Manufacturer))
	fmt.Fprintf(src, "Count: aws.Int32(%d),\n", lo.FromPtr(info.Count))
	if mem := info.MemoryInfo; mem != nil {
		fmt.Fprint(src, "MemoryInfo: &ec2types.GpuDeviceMemoryInfo{\n")
		fmt.Fprintf(src, "SizeInMiB: aws.Int32(%d),\n", lo.FromPtr(mem.SizeInMiB))
		fmt.Fprint(src, "},\n")
	}
	fmt.Fprint(src, "},\n")
	return src.String()
}

// getStringSliceData renders EC2 string-enum values as the elements of a Go
// slice literal: each value is wrapped in double quotes and the results are
// joined with commas (e.g. `"x86_64","arm64"`). An empty slice renders as "".
func getStringSliceData[T ec2types.UsageClassType | ec2types.VirtualizationType | ec2types.ArchitectureType](slice []T) string {
	quoted := make([]string, 0, len(slice))
	for _, value := range slice {
		quoted = append(quoted, fmt.Sprintf(`"%s"`, value))
	}
	return strings.Join(quoted, ",")
}
